# -*- coding: utf-8 -*-
"""
数据清洗
"""
import os
import re
import pymongo
import numpy as np
import pandas as pd
from scrapy.conf import settings


class Cleaning:

    def __init__(self):
        self.i = 0
        self.MONGO_URI = settings['MONGO_URI']                                              # MongoDB IP地址
        self.MONGO_DB = settings['MONGO_DB']                                                # 库名
        self.MONGO_COLL = settings['MONGO_COLL']                                            # 表名
        self.db = pymongo.MongoClient(self.MONGO_URI)[self.MONGO_DB]                        # 建立与MongoDB的连接
        if 'MONGO_AUTHENTICATE' in settings:                                                # 如果需要验证
            self.db.authenticate(name=settings['MONGO_AUTHENTICATE']['name'],
                                 password=settings['MONGO_AUTHENTICATE']['password'],
                                 source=settings['MONGO_AUTHENTICATE']['source'])

    @staticmethod
    def confirm_company_name() -> set:
        """爬取的时是取的列表页第一个公司
        如果搜索的公司不存在并且名字相似也会取下来
        这类数据要过滤掉"""

        excel_name_list = pd.read_excel('医药公司名称.xlsx'                    # Excel表中 需要爬取的公司名
                                        )['公司名称'].to_list()
        folder_list = [root.split('\\')[-1] for root, dirs, files               # 文件夹中公司名字
                       in os.walk(settings['SOURCE'])][1:]

        intersection_set = set(excel_name_list) & set(folder_list)              # 取交集
        print('爬取到的公司数据', len(intersection_set))
        没有爬到的公司列表 = set(excel_name_list) - intersection_set
        print('没有爬到的公司', len(没有爬到的公司列表))
        for folder_name in folder_list:
            if '药' in folder_name or '医' in folder_name:
                intersection_set.add(folder_name)
        print('添加带关键字的公司后', len(intersection_set))
        return intersection_set

    @staticmethod
    def vertical_tables(html_str: str, none_replace: str = None, del_column: list = None) -> dict:
        """
        解析竖向表格
        :param html_str:       html字符串 就是<tr>标签那一坨 不用太精准 在里面就行
        :param none_replace:   默认None 表里的空值替换为None; 也可以传入字符串''or 'aaaa' or 'bbbb' 空值就替换为字符串
        :param del_column:     list 需要删除列的列表 del_column = [3, 4] 删除第3,4列
        :return:               字典
        """

        table_df_list = pd.read_html(html_str, index_col=False)                 # 读取<tr>标签str为DataFrame列表
        if len(table_df_list) > 1:                                              # 如果列表大于1
            table_df_list = [table_df_list[-1]]                                 # 只取最后一个
        for table_df in table_df_list:                                          # 遍历
            df_none = pd.DataFrame(columns=[])                                  # 创建空 DataFrame
            if del_column:                                                      # 如果需要删除某列
                table_df.drop(del_column, axis=1, inplace=True)

            for index_i, ss in table_df.iterrows():                             # 遍历 DataFrame 每一行为Series
                len_a = len(ss)                                                 # Series 里有几个数据
                dup = ss[ss.duplicated()].count()                               # 统计连续的重复值
                if dup > 1:                                                     # 如果这个值大于1
                    ss.drop_duplicates(keep='first', inplace=True)              # 去重 并保留第一个出现的值
                len_b = len(ss)                                                 # 去重后的Series长度
                len_filling = len_a - len_b                                     # 求出需要填充的个数
                nan_list = [np.nan] * len_filling                               # 拼接需要填充NaN(空)的值列表
                ss = ss.append(pd.Series(nan_list), ignore_index=True)          # NaN(空)值添加进Series
                df_none[index_i] = ss                                           # Series添加到DataFrame(df_none变量)

            if len(df_none) / 2 - len(df_none) // 2 == 0:                       # 判断DataFrame是否是的列是不是偶数
                df = pd.DataFrame([])                                           # 创建空DataFrame
                for count_i in range(int(len(df_none) / 2)):                    # 循环拼接loc的查询列表
                    loc_list = [count_i * 2, count_i * 2 + 1]
                    df_split = df_none.loc[loc_list].reset_index(drop=True)     # 重置index顺序
                    df_split = df_split.dropna(axis=1, how='all')               # 删除整列为NaN的值
                    df_split.columns = df_split.iloc[0].tolist()                # 把第0行转为列名
                    df_split = df_split.drop(0, axis=0, inplace=False)          # 删除第0行
                    df = pd.concat([df, df_split], axis=1)                      # 横向拼接DataFrame

                if none_replace is None:                                        # 如果none_switch 没有传参进来
                    df = df.where(df.notnull(), None)                           # 空值替换成None
                else:                                                           # 如果none_switch 有传参进来
                    df = df.replace(np.nan, none_replace)                       # 空值 替换成 空字符串
                dic_list = df.to_dict('records')                                # DataFrame转字典
                for dic in dic_list:
                    yield dic

            else:                                                               # 如果为奇数列
                ex = Exception("表格为奇数列")                                  # 创建异常对象
                raise ex                                                        # 抛出异常对象

    @staticmethod
    def transverse_tables(html_str: str, name: str, replace_dic: dict = None) -> list:
        """
        解析横向表格
        :param html_str:        html字符串
        :param name:            公司名
        :param replace_dic:     需要批量替换的值 {'替换前','替换后'}
        :return:
        """

        table_df_list = pd.read_html(html_str, index_col=False)                 # 读取<tr>标签str为DataFrame列表
        for table_df in table_df_list:
            for i, r in table_df.iterrows():
                if (i % 2) == 0:                                                # 如果是偶数
                    new_data = table_df.loc[                                    # 取下一行的关联公司字段为新数据
                        i+1, '关联公司'].replace('股权结构', '')
                    table_df.loc[i, '关联公司'] = new_data                      # 让关联公司字段等于下一行的数据
            if replace_dic:
                table_df = table_df.replace(replace_dic, regex=True)            # 批量替换
                table_df.dropna(axis=0, how='any', inplace=True)                # 删除有nan的行
                table_df['name'] = [name]*len(table_df)                         # 添加name 列并赋值
            dic_list = table_df.to_dict('records')                              # DataFrame转字典
            return dic_list

    def to_mongo(self, MONGO_COLL, data):
        """存入MongoDB"""
        pass
        if isinstance(data, list):                                              # 如果是list
            self.db[MONGO_COLL].insert_many(data, ordered=False)                # 批量存入
        if isinstance(data, dict):                                              # 如果是字典
            self.db[MONGO_COLL].insert_one(data)                                # 单条存入

    def clean(self, data: dict) -> list:
        """
        清洗数据
        :param data:        每一条原始数据
        :return:           存入 MongoDB的数据
        """
        self.i += 1
        company_name = list(data.keys())[0]                                     # 公司名
        print('<', company_name, '>', self.i)
        value = data[company_name]                                              # 取出data的值
        for value_i in value:                                                   # 循环取出key
            if value_i == 'main.html':                                          # 如果是主页面的源码
                html_text = value[value_i]

                # 企业简介
                company_profile_list = re.findall(                              # 匹配企业简介html文本
                    r'<div class="data-header">企业简介[\s\S]*?<div class="data-header">',
                    html_text)
                if company_profile_list:
                    company_profile_table_df_list = \
                        self.vertical_tables(company_profile_list[0])           # 清洗表格
                    for cptd in company_profile_table_df_list:
                        cptd['name'] = company_name                             # 添加name字段
                        self.to_mongo(self.MONGO_COLL+'_company_profile', cptd) # 入库

                # 联系信息
                contact_information_list = re.findall(                          # 匹配联系信息html文本
                    r'<span class="data-title">联系信息[\s\S]*?<span class="data-title">',
                    html_text)
                if contact_information_list:
                    contact_information_table_df_list = self.vertical_tables(   # 清洗表格
                        contact_information_list[0])
                    for citd in contact_information_table_df_list:
                        citd['name'] = company_name                             # 添加name字段
                        self.to_mongo(self.MONGO_COLL +                         # 入库
                                      '_contact_information', citd)

                # 工商信息
                business_information_list = re.findall(                         # 匹配工商信息html文本
                    r'<span class="data-title">工商信息'
                    r'[\s\S]*?(<tbody>[\s\S]*?注册资本[\s\S]*?<span class="data-title">)',
                    html_text)
                if business_information_list:
                    business_information = re.sub(                              # 删除广告
                        '<td width="145px"[\s\S]*?</td>','',
                        business_information_list[0])
                    business_information_table_df_list = \
                        self.vertical_tables(business_information, None, [4])   # 清洗表格
                    for bitd in business_information_table_df_list:
                        bitd['name'] = company_name                             # 添加name字段
                        self.to_mongo(self.MONGO_COLL +                         # 入库
                                      '_business_information', bitd)

            elif value_i == 'affiliated.html':                                   # 如果是参股控股页面
                html_text = value[value_i]
                if html_text == '':
                    pass
                else:
                    data_list = self.transverse_tables(                         # 清洗表格
                        html_text, company_name, {r'...\xa0更多': ''})
                    self.to_mongo(self.MONGO_COLL+'_affiliated', data_list)     # 入库


class ReadFolderData:
    """读取文件夹下数据"""

    @staticmethod
    def read(folder_path: str) -> dict:
        """
        :param folder_path:     读取的文件夹路径
        :return:                文件夹中的数据(字典)
                                html的返回里边内容 不是html的返回路径
        """

        count_i = 0

        for root, dirs, files in os.walk(folder_path):                          # 循环文件夹
            if count_i == 0:                                                    # 过滤掉第0次
                count_i += 1
            else:
                folder_name = root.split('\\')[-1]                              # 子文件夹名
                dic = {folder_name: {}}                                         # 定义字典
                for file in files:                                              # 循环子文件夹下文件名
                    file_name = file.split('\\')[-1]                            # 子文件夹下文件名
                    file_path = root + '\\' + file                              # 单个文件的路径
                    suffix = file_name.split('.')[-1]                           # 文件后缀
                    if suffix in 'html':                                        # 如果是html
                        with open(file_path, "r", encoding='utf-8') as f:
                            text = f.read()                                     # 读取里面内容
                    else:                                                       # 如果不是html
                        text = file_path                                        # 返回路径
                    dic[folder_name][file_name] = text                          # 组装字典
                yield dic                                                       # 可迭代对象


def main():
    """主函数"""
    c = Cleaning()
    r = ReadFolderData()
    intersection_set = c.confirm_company_name()                                 # 获取需要的公司名
    data_list = r.read(settings['SOURCE'])                                      # 获取数据的可迭代对象
    for data in data_list:                                                      # 遍历可迭代对象
        for data_i in data:                                                     # 再遍历字典
            if data_i in intersection_set:                                      # 如果键名在set里面
                c.clean(data)                                                   # 交给清洗函数


if __name__ == '__main__':
    main()
    # c = Cleaning()
    # r = ReadFolderData()
    # intersection_set = c.confirm_company_name()  # 获取需要的公司名
